# pip install websocket-client
import json
import ssl
import threading
import time
import traceback
import wave
from queue import Empty, Queue

import numpy as np
import pyaudio
from websocket import ABNF
from websocket import create_connection


# Websocket client for a FunASR speech-recognition server.
class Funasr_websocket_recognizer:
    """
    python asr recognizer lib

    Opens a websocket connection to a FunASR server, starts a background
    thread that decodes server messages into a queue, and exposes
    feed_chunk() / close() to stream raw audio bytes and collect results.
    """

    def __init__(
            self,
            host="127.0.0.1",
            port="30035",
            is_ssl=True,
            chunk_size="0, 10, 5",
            chunk_interval=10,
            mode="offline",
            wav_name="default",
    ):
        """
        host: server host ip
        port: server port
        is_ssl: True for wss protocol, False for ws
        chunk_size: comma-separated ints, forwarded to the server as a list
        chunk_interval: streaming chunk interval, forwarded to the server
        mode: recognition mode (e.g. "offline", "online", "2pass")
        wav_name: stream name reported to the server
        """
        try:
            if is_ssl:
                # FunASR servers commonly use self-signed certificates,
                # so certificate verification is disabled on purpose.
                ssl_context = ssl.SSLContext()
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
                uri = "wss://{}:{}".format(host, port)
                ssl_opt = {"cert_reqs": ssl.CERT_NONE}
            else:
                uri = "ws://{}:{}".format(host, port)
                ssl_context = None
                ssl_opt = None
            self.host = host
            self.port = port

            self.msg_queue = Queue()  # decoded result messages from the server

            print("connect to url", uri)
            self.websocket = create_connection(uri, ssl=ssl_context, sslopt=ssl_opt)

            # daemon=True: a client that forgets to call close() must not
            # keep the interpreter alive on this blocked recv() thread.
            self.thread_msg = threading.Thread(
                target=Funasr_websocket_recognizer.thread_rec_msg,
                args=(self,),
                daemon=True,
            )
            self.thread_msg.start()
            chunk_size = [int(x) for x in chunk_size.split(",")]

            # Handshake message describing the upcoming audio stream.
            message = json.dumps(
                {
                    "mode": mode,
                    "chunk_size": chunk_size,
                    "encoder_chunk_look_back": 4,
                    "decoder_chunk_look_back": 1,
                    "chunk_interval": chunk_interval,
                    "wav_name": wav_name,
                    "is_speaking": True,
                }
            )

            self.websocket.send(message)

            print("send json", message)

        except Exception as e:
            # NOTE(review): errors are swallowed to keep the original
            # non-raising constructor contract; the instance may be left
            # without a usable .websocket attribute afterwards.
            print("Exception:", e)
            traceback.print_exc()

    # Background thread: push every decoded server message onto msg_queue.
    def thread_rec_msg(self):
        try:
            while True:
                msg = self.websocket.recv()
                if msg is None or len(msg) == 0:
                    continue
                msg = json.loads(msg)

                self.msg_queue.put(msg)
        except Exception:
            # recv() raises once the socket is closed; treat as shutdown.
            print("client closed")

    # feed data to asr engine, wait_time means waiting for result until time out
    def feed_chunk(self, chunk, wait_time=0.01):
        """Send one binary audio chunk; return the newest pending result
        message (a dict), or "" when nothing arrived within wait_time."""
        try:
            self.websocket.send(chunk, ABNF.OPCODE_BINARY)
        except Exception:
            # Send failure (e.g. closed socket): report "no result".
            return ""
        msg = ""
        try:
            # Drain the queue so only the most recent message is returned.
            while True:
                msg = self.msg_queue.get(timeout=wait_time)
                if self.msg_queue.empty():
                    break
        except Empty:
            # Timeout: return whatever was collected so far ("" if nothing).
            # The old bare `except:` discarded an already-retrieved message
            # here; keeping `msg` fixes that silent drop.
            pass
        return msg

    def close(self, timeout=1):
        """Tell the server the stream ended, wait `timeout` seconds for the
        final result, close the socket, and return the last message
        (a dict), or "" when no message was pending."""
        message = json.dumps({"is_speaking": False})
        self.websocket.send(message)
        # sleep for timeout seconds to wait for result
        time.sleep(timeout)
        msg = ""
        while not self.msg_queue.empty():
            msg = self.msg_queue.get()

        self.websocket.close()
        # only return the last msg
        return msg


if __name__ == "__main__":

    print("example for Funasr_websocket_recognizer")

    # Microphone capture configuration.
    FORMAT = pyaudio.paInt16  # 16-bit PCM samples
    CHANNELS = 1              # mono
    RATE = 16000              # sample rate expected by the server
    # 60 ms of 16 kHz 16-bit audio expressed in bytes (1920).
    # NOTE(review): this value is also passed to stream.read() as a *frame*
    # count, so each read actually yields 2x this many bytes — confirm units.
    CHUNK = int(60 * 10 / 10 / 1000 * RATE * 2)

    # Initialize PyAudio and open the default input device.
    p = pyaudio.PyAudio()
    stream = p.open(format=FORMAT, channels=CHANNELS,
                    rate=RATE, input=True,
                    frames_per_buffer=CHUNK)

    rcg = None  # sentinel so `finally` can tell whether construction happened
    try:
        # Create a recognizer connected to the local FunASR server.
        rcg = Funasr_websocket_recognizer(
            host="127.0.0.1", port="10095", is_ssl=True, mode="2pass", chunk_size="0,10,5"
        )
        # Loop: capture microphone audio and stream it to the server.
        print("Start speaking...")
        while True:
            # Capture roughly 5 seconds of audio.
            frames = []
            for _ in range(int(RATE / CHUNK * 5)):
                frames.append(stream.read(CHUNK, exception_on_overflow=False))
            audio_bytes = b''.join(frames)
            # Re-slice the capture into CHUNK-sized pieces and feed each one.
            chunk_num = (len(audio_bytes) - 1) // CHUNK + 1
            for i in range(chunk_num):
                beg = i * CHUNK
                data = audio_bytes[beg: beg + CHUNK]
                text = rcg.feed_chunk(data, wait_time=0.02)
                # feed_chunk returns "" on timeout and a dict otherwise;
                # only print finalized (offline-pass) results.
                if isinstance(text, dict) and text.get('mode') == '2pass-offline':
                    print(text.get('text', ''))
                time.sleep(0.05)
    except KeyboardInterrupt:
        # User interrupted the program (e.g. Ctrl+C).
        print("\n程序已中断")
    finally:
        # Release the microphone, then shut down the recognizer.
        stream.stop_stream()
        stream.close()
        p.terminate()
        if rcg is not None:
            # close() returns the last message dict, or "" when nothing
            # arrived — guard the subscript to avoid a TypeError on "".
            text = rcg.close(timeout=3)
            if isinstance(text, dict):
                print("Final text", text.get("text", ""))